# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

import pytest
from mistral_common.exceptions import InvalidMessageStructureException
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

from vllm.transformers_utils.tokenizers.mistral import (
    MistralTokenizer,
    _prepare_apply_chat_template_tools_and_messages,
)


def _current_time_user_message() -> dict[str, Any]:
    """Return a fresh user turn asking for the current local date and time."""
    return {
        "role": "user",
        "content": "What is the current local date and time?",
    }


def _current_time_tool(**function_fields: Any) -> dict[str, Any]:
    """Return a fresh ``get_current_time`` tool definition.

    Extra keyword arguments are merged into the ``"function"`` mapping (e.g.
    ``parameters={}``). A brand-new object is built on every call so that the
    request and the expected output never alias each other.
    """
    function: dict[str, Any] = {
        "description": "Fetch the current local date and time.",
        "name": "get_current_time",
    }
    function.update(function_fields)
    return {"type": "function", "function": function}


@pytest.mark.parametrize(
    "openai_request,expected_mistral_output",
    [
        # A tool declared without "parameters" comes back with an empty one.
        (
            {
                "messages": [_current_time_user_message()],
                "tools": [_current_time_tool()],
            },
            (
                [_current_time_user_message()],
                [_current_time_tool(parameters={})],
            ),
        ),
        # A tool that already declares empty "parameters" is left as-is.
        (
            {
                "messages": [_current_time_user_message()],
                "tools": [_current_time_tool(parameters={})],
            },
            (
                [_current_time_user_message()],
                [_current_time_tool(parameters={})],
            ),
        ),
    ],
)
def test_prepare_apply_chat_template_tools_and_messages(
    openai_request, expected_mistral_output
):
    """Tool definitions are normalized while the messages pass through."""
    messages = openai_request["messages"]
    tools = openai_request["tools"]
    prepared = _prepare_apply_chat_template_tools_and_messages(messages, tools)
    assert prepared == expected_mistral_output


def _paris_weather_tool() -> dict[str, Any]:
    """Return a fresh ``get_weather`` tool definition taking a ``city``."""
    city_schema = {"type": "string", "description": "The city name"}
    return {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Gets the current weather in a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": city_schema},
                "required": ["city"],
            },
        },
    }


def _paris_weather_messages(*, with_reasoning: bool) -> list[dict[str, Any]]:
    """Return a fresh user -> assistant tool call -> tool result conversation.

    The tool result uses list-style content. When ``with_reasoning`` is true
    the assistant turn also carries a ``"reasoning": None`` entry, so the
    request variant and the expected variant come from one definition.
    """
    assistant_turn: dict[str, Any] = {"role": "assistant"}
    if with_reasoning:
        assistant_turn["reasoning"] = None
    assistant_turn["content"] = None
    assistant_turn["tool_calls"] = [
        {
            "id": "call123",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"city": "Paris"}',
            },
        }
    ]
    return [
        {"role": "user", "content": "What's the weather in Paris?"},
        assistant_turn,
        {
            "role": "tool",
            "content": [{"type": "text", "text": "Rainy"}],
            "name": "get_weather",
            "tool_call_id": "call123",
        },
    ]


# Tool use with list content and reasoning.
@pytest.mark.parametrize(
    "openai_request,expected_mistral_output",
    [
        (
            {
                "messages": _paris_weather_messages(with_reasoning=True),
                "tools": [_paris_weather_tool()],
            },
            (
                _paris_weather_messages(with_reasoning=False),
                [_paris_weather_tool()],
            ),
        )
    ],
)
def test_prepare_apply_chat_template_tools_and_messages_list_content(
    openai_request, expected_mistral_output
):
    """The assistant "reasoning" key is dropped; everything else is kept.

    List-style tool content and the tool definitions are expected unchanged.
    """
    prepared = _prepare_apply_chat_template_tools_and_messages(
        openai_request["messages"], openai_request["tools"]
    )
    assert prepared == expected_mistral_output


def test_prepare_apply_chat_template_generation_prompt_and_continue():
    """Check how the final message's role interacts with the prompt flags.

    ``add_generation_prompt=True`` is rejected after an assistant turn, and
    ``continue_final_message=True`` is rejected after a user turn. Asking for
    both at once on a user-final conversation is also rejected. Every
    rejection must surface as ``ValueError``; accepted calls return the
    messages unchanged.
    """
    tools: list[dict[str, Any]] = []

    def prepare(messages, **flags):
        return _prepare_apply_chat_template_tools_and_messages(
            messages, tools, **flags
        )

    # Generation prompt requested after an assistant turn -> rejected.
    with pytest.raises(ValueError):
        prepare(
            [{"role": "assistant", "content": "Hello"}],
            add_generation_prompt=True,
        )

    # Generation prompt after a user turn -> accepted, messages pass through.
    user_final = [{"role": "user", "content": "Hello"}]
    prepared, _ = prepare(user_final, add_generation_prompt=True)
    assert prepared == [{"role": "user", "content": "Hello"}]

    # Both flags on the same user-final conversation -> rejected.
    with pytest.raises(ValueError):
        prepare(
            user_final, add_generation_prompt=True, continue_final_message=True
        )

    # Continuing a final assistant turn -> accepted, messages pass through.
    prepared, _ = prepare(
        [{"role": "assistant", "content": "Hello"}],
        add_generation_prompt=False,
        continue_final_message=True,
    )
    assert prepared == [{"role": "assistant", "content": "Hello"}]

    # Continuing when the last turn belongs to the user -> rejected.
    with pytest.raises(ValueError):
        prepare(
            [{"role": "user", "content": "Hello"}],
            add_generation_prompt=False,
            continue_final_message=True,
        )


@pytest.fixture(scope="module")
def mistral_tokenizer(request) -> MistralTokenizer:
    """Load a MistralTokenizer for the checkpoint given in ``request.param``.

    Intended for indirect parametrization; the module scope means each
    checkpoint name is loaded at most once per test module.
    """
    checkpoint = request.param
    return MistralTokenizer.from_pretrained(checkpoint)


@pytest.mark.parametrize(
    "mistral_tokenizer",
    ["mistralai/Mistral-7B-Instruct-v0.3", "mistralai/Magistral-Small-2509"],
    indirect=True,
)
class TestMistralTokenizer:
    def test_all_special_tokens(self, mistral_tokenizer: MistralTokenizer):
        """Both special-token accessors return the same fully spelled-out list.

        Tekken tokenizers interleave named control tokens with ``<SPECIAL_i>``
        placeholders; the non-Tekken tokenizer lists its named tokens followed
        by ``[control_i]`` placeholders.
        """
        observed = (
            mistral_tokenizer.all_special_tokens,
            mistral_tokenizer.all_special_tokens_extended,
        )

        if mistral_tokenizer.is_tekken:
            expected = [
                "<unk>",
                "<s>",
                "</s>",
                "[INST]",
                "[/INST]",
                "[AVAILABLE_TOOLS]",
                "[/AVAILABLE_TOOLS]",
                "[TOOL_RESULTS]",
                "[/TOOL_RESULTS]",
                "[TOOL_CALLS]",
                "[IMG]",
                "<pad>",
                "[IMG_BREAK]",
                "[IMG_END]",
                "[PREFIX]",
                "[MIDDLE]",
                "[SUFFIX]",
                "[SYSTEM_PROMPT]",
                "[/SYSTEM_PROMPT]",
                "[TOOL_CONTENT]",
            ]
            expected.extend(f"<SPECIAL_{i}>" for i in range(20, 32))
            expected.extend(["[ARGS]", "[CALL_ID]", "[THINK]", "[/THINK]"])
            expected.extend(f"<SPECIAL_{i}>" for i in range(36, 1000))
        else:
            expected = [
                "<s>",
                "</s>",
                "[INST]",
                "[/INST]",
                "[TOOL_CALLS]",
                "[AVAILABLE_TOOLS]",
                "[/AVAILABLE_TOOLS]",
                "[TOOL_RESULTS]",
                "[/TOOL_RESULTS]",
            ]
            expected.extend(f"[control_{i}]" for i in range(8, 769))

        for special_tokens in observed:
            assert special_tokens == expected

    def test_get_vocab(self, mistral_tokenizer: MistralTokenizer):
        """get_vocab() matches the wrapped transformers tokenizer's vocab.

        The ``test_`` prefix is what makes pytest collect this method.
        """
        assert (
            mistral_tokenizer.get_vocab()
            == mistral_tokenizer.transformers_tokenizer.get_vocab()
        )

    # Backward-compatible alias for the previous method name. That name lacked
    # the "test" prefix, so pytest never ran the check. pytest does not
    # collect this alias either, so the test still runs only once.
    get_vocab = test_get_vocab

    def test_get_added_vocab(self, mistral_tokenizer: MistralTokenizer):
        """The tokenizer reports an empty added-vocabulary mapping."""
        added_vocab = mistral_tokenizer.get_added_vocab()
        assert added_vocab == {}

    def test_encode_one(self, mistral_tokenizer: MistralTokenizer):
        """encode_one returns bare text ids and truncates only on request.

        Unlike ``encode``, no leading ``<s>`` id is expected. ``max_length`` on
        its own is ignored; it only takes effect together with
        ``truncation=True``.
        """
        text = "Hello world !"
        if mistral_tokenizer.is_tekken:
            expected = [22177, 4304, 2662]
        else:
            expected = [23325, 2294, 1686]

        cases = [
            ({}, expected),
            ({"max_length": 1}, expected),
            ({"truncation": True, "max_length": 1}, expected[:1]),
            ({"truncation": False, "max_length": 1}, expected),
        ]
        for kwargs, want in cases:
            assert mistral_tokenizer.encode_one(text, **kwargs) == want, kwargs

    def test_encode(self, mistral_tokenizer: MistralTokenizer):
        """encode prepends ``<s>`` by default and truncates to ``max_length``.

        ``max_length`` truncates unless ``truncation=False`` is passed, and
        ``add_special_tokens=False`` drops the leading special id.
        """
        text = "Hello world !"
        if mistral_tokenizer.is_tekken:
            expected = [1, 22177, 4304, 2662]
        else:
            expected = [1, 23325, 2294, 1686]

        cases = [
            ({}, expected),
            ({"max_length": 3}, expected[:3]),
            ({"truncation": True, "max_length": 3}, expected[:3]),
            ({"truncation": False, "max_length": 3}, expected),
            ({"add_special_tokens": True}, expected),
            ({"add_special_tokens": True, "max_length": 3}, expected[:3]),
            (
                {"add_special_tokens": True, "truncation": False, "max_length": 3},
                expected,
            ),
            ({"add_special_tokens": False}, expected[1:]),
        ]
        for kwargs, want in cases:
            assert mistral_tokenizer.encode(text, **kwargs) == want, kwargs

    @pytest.mark.parametrize(
        "openai_request,add_generation_prompt,continue_final_message,expected_output,decoded_expected_output",
        [
            (
                {
                    "messages": [
                        {
                            "role": "user",
                            "content": "Hello world !",
                        }
                    ],
                },
                True,
                False,
                ([1, 3, 23325, 2294, 1686, 4], [1, 3, 22177, 4304, 2662, 4]),
                ("<s>[INST]▁Hello▁world▁![/INST]", ("<s>[INST]Hello world ![/INST]")),
            ),
            (
                {
                    "messages": [
                        {
                            "role": "system",
                            "content": "I am an AI",
                        },
                        {
                            "role": "user",
                            "content": "Hello world !",
                        },
                    ],
                },
                True,
                False,
                (
                    [1, 3, 1083, 1605, 1164, 16875, 781, 781, 16998, 2294, 1686, 4],
                    [1, 17, 1073, 1855, 1420, 26554, 18, 3, 22177, 4304, 2662, 4],
                ),
                (
                    "<s>[INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]",
                    (
                        "<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][INST]Hello world ![/INST]"  # noqa: E501
                    ),
                ),
            ),
            (
                {
                    "messages": [
                        {
                            "role": "system",
                            "content": "I am an AI",
                        },
                        {
                            "role": "user",
                            "content": "Hello world !",
                        },
                    ],
                    "tools": [
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "description": "Gets the current weather in a city.",
                                "parameters": {
                                    "type": "object",
                                    "properties": {
                                        "city": {
                                            "type": "string",
                                            "description": "The city name",
                                        }
                                    },
                                    "required": ["city"],
                                },
                            },
                        }
                    ],
                },
                True,
                False,
                (
                    [
                        1,
                        6,
                        1501,
                        7567,
                        1891,
                        2032,
                        1113,
                        3396,
                        1316,
                        1113,
                        3396,
                        2032,
                        10598,
                        1629,
                        2032,
                        1113,
                        1295,
                        29498,
                        1537,
                        1991,
                        1316,
                        1113,
                        7286,
                        2032,
                        1113,
                        2226,
                        29481,
                        1040,
                        2636,
                        8854,
                        1065,
                        1032,
                        3758,
                        9959,
                        1113,
                        12206,
                        2032,
                        10598,
                        1891,
                        2032,
                        1113,
                        3582,
                        1316,
                        1113,
                        11491,
                        2032,
                        10598,
                        19141,
                        2032,
                        10598,
                        1891,
                        2032,
                        1113,
                        2195,
                        1316,
                        1113,
                        7286,
                        2032,
                        1113,
                        1782,
                        3758,
                        1909,
                        29507,
                        11549,
                        1113,
                        11661,
                        2032,
                        8135,
                        19141,
                        3010,
                        1743,
                        10925,
                        7,
                        3,
                        1083,
                        1605,
                        1164,
                        16875,
                        781,
                        781,
                        16998,
                        2294,
                        1686,
                        4,
                    ],
                    [
                        1,
                        17,
                        1073,
                        1855,
                        1420,
                        26554,
                        18,
                        5,
                        1091,
                        19227,
                        4994,
                        2811,
                        1429,
                        5165,
                        1897,
                        1429,
                        5165,
                        2811,
                        16753,
                        2391,
                        2811,
                        1429,
                        1689,
                        1095,
                        45629,
                        1897,
                        1429,
                        14653,
                        2811,
                        1429,
                        1071,
                        3083,
                        1278,
                        3519,
                        17253,
                        1294,
                        1261,
                        5970,
                        39249,
                        1429,
                        26204,
                        2811,
                        16753,
                        4994,
                        2811,
                        1429,
                        6371,
                        1897,
                        1429,
                        48649,
                        2811,
                        16753,
                        29363,
                        2811,
                        16753,
                        4994,
                        2811,
                        1429,
                        3607,
                        1897,
                        1429,
                        14653,
                        2811,
                        1429,
                        1784,
                        5970,
                        2564,
                        1034,
                        47579,
                        1429,
                        15760,
                        2811,
                        12161,
                        29363,
                        4964,
                        2821,
                        27028,
                        6,
                        3,
                        22177,
                        4304,
                        2662,
                        4,
                    ],
                ),
                (
                    '<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS][INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST]',
                    (
                        '<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS][INST]Hello world ![/INST]'  # noqa: E501
                    ),
                ),
            ),
            (
                {
                    "messages": [
                        {
                            "role": "system",
                            "content": "I am an AI",
                        },
                        {
                            "role": "user",
                            "content": "Hello world !",
                        },
                        {
                            "role": "assistant",
                            "content": "",
                            "tool_calls": [
                                {
                                    "id": "123456789",
                                    "type": "function",
                                    "function": {
                                        "name": "get_weather",
                                        "arguments": '{"city": "Paris"}',
                                    },
                                }
                            ],
                        },
                        {
                            "role": "tool",
                            "tool_call_id": "123456789",
                            "content": '{"temperature": 20, "unit": "celsius"}',
                        },
                    ],
                    "tools": [
                        {
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "description": "Gets the current weather in a city.",
                                "parameters": {
                                    "type": "object",
                                    "properties": {
                                        "city": {
                                            "type": "string",
                                            "description": "The city name",
                                        }
                                    },
                                    "required": ["city"],
                                },
                            },
                        }
                    ],
                },
                True,
                False,
                (
                    [
                        1,
                        6,
                        1501,
                        7567,
                        1891,
                        2032,
                        1113,
                        3396,
                        1316,
                        1113,
                        3396,
                        2032,
                        10598,
                        1629,
                        2032,
                        1113,
                        1295,
                        29498,
                        1537,
                        1991,
                        1316,
                        1113,
                        7286,
                        2032,
                        1113,
                        2226,
                        29481,
                        1040,
                        2636,
                        8854,
                        1065,
                        1032,
                        3758,
                        9959,
                        1113,
                        12206,
                        2032,
                        10598,
                        1891,
                        2032,
                        1113,
                        3582,
                        1316,
                        1113,
                        11491,
                        2032,
                        10598,
                        19141,
                        2032,
                        10598,
                        1891,
                        2032,
                        1113,
                        2195,
                        1316,
                        1113,
                        7286,
                        2032,
                        1113,
                        1782,
                        3758,
                        1909,
                        29507,
                        11549,
                        1113,
                        11661,
                        2032,
                        8135,
                        19141,
                        3010,
                        1743,
                        10925,
                        7,
                        3,
                        1083,
                        1605,
                        1164,
                        16875,
                        781,
                        781,
                        16998,
                        2294,
                        1686,
                        4,
                        5,
                        1501,
                        7567,
                        1629,
                        2032,
                        1113,
                        1295,
                        29498,
                        1537,
                        1991,
                        1316,
                        1113,
                        17452,
                        2032,
                        10598,
                        19141,
                        2032,
                        1113,
                        4684,
                        1046,
                        8474,
                        1113,
                        1081,
                        2032,
                        1113,
                        29508,
                        29518,
                        29538,
                        29549,
                        29550,
                        29552,
                        29555,
                        29551,
                        29542,
                        29507,
                        10925,
                        2,
                        8,
                        10598,
                        4557,
                        2032,
                        10598,
                        29475,
                        17329,
                        2032,
                        29473,
                        29518,
                        29502,
                        29493,
                        1113,
                        6074,
                        2032,
                        1113,
                        29485,
                        1958,
                        3938,
                        8474,
                        1113,
                        3613,
                        29498,
                        1081,
                        2032,
                        1113,
                        29508,
                        29518,
                        29538,
                        29549,
                        29550,
                        29552,
                        29555,
                        29551,
                        29542,
                        18163,
                        9,
                    ],
                    [
                        1,
                        17,
                        1073,
                        1855,
                        1420,
                        26554,
                        18,
                        5,
                        1091,
                        19227,
                        4994,
                        2811,
                        1429,
                        5165,
                        1897,
                        1429,
                        5165,
                        2811,
                        16753,
                        2391,
                        2811,
                        1429,
                        1689,
                        1095,
                        45629,
                        1897,
                        1429,
                        14653,
                        2811,
                        1429,
                        1071,
                        3083,
                        1278,
                        3519,
                        17253,
                        1294,
                        1261,
                        5970,
                        39249,
                        1429,
                        26204,
                        2811,
                        16753,
                        4994,
                        2811,
                        1429,
                        6371,
                        1897,
                        1429,
                        48649,
                        2811,
                        16753,
                        29363,
                        2811,
                        16753,
                        4994,
                        2811,
                        1429,
                        3607,
                        1897,
                        1429,
                        14653,
                        2811,
                        1429,
                        1784,
                        5970,
                        2564,
                        1034,
                        47579,
                        1429,
                        15760,
                        2811,
                        12161,
                        29363,
                        4964,
                        2821,
                        27028,
                        6,
                        3,
                        22177,
                        4304,
                        2662,
                        4,
                        9,
                        1689,
                        1095,
                        45629,
                        32,
                        19227,
                        29363,
                        2811,
                        1429,
                        42572,
                        46005,
                        2,
                        7,
                        19227,
                        113824,
                        2811,
                        1032,
                        1050,
                        1048,
                        1044,
                        1429,
                        8979,
                        2811,
                        1429,
                        1099,
                        79092,
                        46005,
                        8,
                    ],
                ),
                (
                    '<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"get_weather",▁"description":▁"Gets▁the▁current▁weather▁in▁a▁city.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"city":▁{"type":▁"string",▁"description":▁"The▁city▁name"}},▁"required":▁["city"]}}}][/AVAILABLE_TOOLS][INST]▁I▁am▁an▁AI<0x0A><0x0A>Hello▁world▁![/INST][TOOL_CALLS]▁[{"name":▁"get_weather",▁"arguments":▁{"city":▁"Paris"},▁"id":▁"123456789"}]</s>[TOOL_RESULTS]▁{"content":▁{"temperature":▁20,▁"unit":▁"celsius"},▁"call_id":▁"123456789"}[/TOOL_RESULTS]',
                    (
                        '<s>[SYSTEM_PROMPT]I am an AI[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}][/AVAILABLE_TOOLS][INST]Hello world ![/INST][TOOL_CALLS]get_weather[ARGS]{"city": "Paris"}</s>[TOOL_RESULTS]{"temperature": 20, "unit": "celsius"}[/TOOL_RESULTS]'  # noqa: E501
                    ),
                ),
            ),
            (
                {
                    "messages": [
                        {
                            "role": "user",
                            "content": "Hello world !",
                        },
                        {
                            "role": "assistant",
                            "content": "Hello ",
                        },
                    ],
                },
                False,
                True,
                (
                    [1, 3, 23325, 2294, 1686, 4, 23325],
                    [1, 3, 22177, 4304, 2662, 4, 22177, 2],
                ),
                (
                    "<s>[INST]▁Hello▁world▁![/INST]▁Hello",
                    ("<s>[INST]Hello world ![/INST]Hello</s>"),
                ),
            ),
        ],
    )
    def test_apply_chat_template(
        self,
        mistral_tokenizer: MistralTokenizer,
        openai_request: dict[str, Any],
        add_generation_prompt: bool,
        continue_final_message: bool,
        expected_output: tuple[list[int], list[int]],
        decoded_expected_output: tuple[str, str],
    ):
        """Check chat-template token ids and their decoded rendering.

        ``expected_output`` and ``decoded_expected_output`` are pairs whose
        index 0 is the non-Tekken expectation and index 1 the Tekken one;
        the pair element is chosen via ``mistral_tokenizer.is_tekken``.
        """
        request_tools = openai_request.get("tools", [])
        token_ids = mistral_tokenizer.apply_chat_template(
            openai_request["messages"],
            tools=request_tools,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
        )
        # Keep special tokens so control markers (e.g. [INST]) are compared too.
        rendered = mistral_tokenizer.tokenizer.decode(
            token_ids, SpecialTokenPolicy.KEEP
        )

        variant = mistral_tokenizer.is_tekken
        assert token_ids == expected_output[variant]
        assert rendered == decoded_expected_output[variant]

    def test_apply_chat_template_error(self, mistral_tokenizer: MistralTokenizer):
        """Reject invalid combinations of chat-template flags and messages.

        Cases, in order:
        - ``add_generation_prompt`` and ``continue_final_message`` both set
          -> ``ValueError``.
        - ``continue_final_message`` with a user message last -> ``ValueError``.
        - ``add_generation_prompt`` with an assistant message last
          -> ``ValueError``.
        - An assistant message last with neither flag set
          -> mistral_common's ``InvalidMessageStructureException``.
        """
        # The first two cases share one list object, as in the original flow.
        user_only = [{"role": "user", "content": "Hello world !"}]
        scenarios = [
            (user_only, True, True, ValueError),
            (user_only, False, True, ValueError),
            (
                [
                    {"role": "user", "content": "Hello world !"},
                    {"role": "assistant", "content": "Hello "},
                ],
                True,
                False,
                ValueError,
            ),
            (
                [
                    {"role": "user", "content": "Hello world !"},
                    {"role": "assistant", "content": "Hello "},
                ],
                False,
                False,
                InvalidMessageStructureException,
            ),
        ]

        for conversation, gen_prompt, continue_final, expected_exc in scenarios:
            with pytest.raises(expected_exc):
                mistral_tokenizer.apply_chat_template(
                    conversation,
                    tools=[],
                    add_generation_prompt=gen_prompt,
                    continue_final_message=continue_final,
                )

    @pytest.mark.parametrize(
        "skip_special_tokens,expected_tokens",
        (
            (
                False,
                (
                    "<s>[INST]▁Hello▁world▁![/INST]▁Hello</s>",
                    "<s>[INST]Hello world ![/INST]Hello</s>",
                ),
            ),
            (True, ("Hello world ! Hello", "Hello world !Hello")),
        ),
    )
    def test_decode(
        self,
        mistral_tokenizer: MistralTokenizer,
        skip_special_tokens: bool,
        expected_tokens: tuple[str, str],
    ):
        """Decode a short user/assistant turn with and without special tokens.

        Both the id sequences below and ``expected_tokens`` are pairs
        (non-Tekken at index 0, Tekken at index 1) selected by
        ``mistral_tokenizer.is_tekken``.
        """
        non_tekken_ids = [1, 3, 23325, 2294, 1686, 4, 23325, 2]
        tekken_ids = [1, 3, 22177, 4304, 2662, 4, 22177, 2]
        ids_by_variant = (non_tekken_ids, tekken_ids)

        variant = mistral_tokenizer.is_tekken
        decoded = mistral_tokenizer.decode(
            ids_by_variant[variant],
            skip_special_tokens=skip_special_tokens,
        )
        assert decoded == expected_tokens[variant]

    def test_decode_int(
        self,
        mistral_tokenizer: MistralTokenizer,
    ):
        """Decoding a single int id (not a list) yields that token's text.

        With special tokens kept, id 1 is expected to decode to ``<s>``.
        """
        single_id = 1
        text = mistral_tokenizer.decode(single_id, skip_special_tokens=False)
        assert text == "<s>"

    def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer):
        """Join token strings back into text for a tool-calling conversation.

        The expected texts show that most control tokens (``<s>``, ``[INST]``,
        ``[AVAILABLE_TOOLS]``, ``[TOOL_RESULTS]``, ...) are dropped while
        ``[TOOL_CALLS]`` is kept. Index 0 holds the non-Tekken variant and
        index 1 the Tekken variant, selected by ``mistral_tokenizer.is_tekken``.
        """
        non_tekken_tokens = [
            "<s>",
            "[AVAILABLE_TOOLS]",
            "▁[",
            '{"',
            "type",
            '":',
            '▁"',
            "function",
            '",',
            '▁"',
            "function",
            '":',
            '▁{"',
            "name",
            '":',
            '▁"',
            "get",
            "_",
            "we",
            "ather",
            '",',
            '▁"',
            "description",
            '":',
            '▁"',
            "Get",
            "s",
            "▁the",
            "▁current",
            "▁weather",
            "▁in",
            "▁a",
            "▁city",
            '.",',
            '▁"',
            "parameters",
            '":',
            '▁{"',
            "type",
            '":',
            '▁"',
            "object",
            '",',
            '▁"',
            "properties",
            '":',
            '▁{"',
            "city",
            '":',
            '▁{"',
            "type",
            '":',
            '▁"',
            "string",
            '",',
            '▁"',
            "description",
            '":',
            '▁"',
            "The",
            "▁city",
            "▁name",
            '"',
            "}},",
            '▁"',
            "required",
            '":',
            '▁["',
            "city",
            '"]',
            "}}",
            "}]",
            "[/AVAILABLE_TOOLS]",
            "[INST]",
            "▁I",
            "▁am",
            "▁an",
            "▁AI",
            "<0x0A>",
            "<0x0A>",
            "Hello",
            "▁world",
            "▁!",
            "[/INST]",
            "[TOOL_CALLS]",
            "▁[",
            '{"',
            "name",
            '":',
            '▁"',
            "get",
            "_",
            "we",
            "ather",
            '",',
            '▁"',
            "arguments",
            '":',
            '▁{"',
            "city",
            '":',
            '▁"',
            "Par",
            "is",
            '"},',
            '▁"',
            "id",
            '":',
            '▁"',
            "1",
            "2",
            "3",
            "4",
            "5",
            "6",
            "7",
            "8",
            "9",
            '"',
            "}]",
            "</s>",
            "[TOOL_RESULTS]",
            '▁{"',
            "content",
            '":',
            '▁{"',
            "t",
            "emperature",
            '":',
            "▁",
            "2",
            "0",
            ",",
            '▁"',
            "unit",
            '":',
            '▁"',
            "c",
            "els",
            "ius",
            '"},',
            '▁"',
            "call",
            "_",
            "id",
            '":',
            '▁"',
            "1",
            "2",
            "3",
            "4",
            "5",
            "6",
            "7",
            "8",
            "9",
            '"}',
            "[/TOOL_RESULTS]",
        ]
        tekken_tokens = [
            "<s>",
            "[SYSTEM_PROMPT]",
            "I",
            " am",
            " an",
            " AI",
            "[/SYSTEM_PROMPT]",
            "[AVAILABLE_TOOLS]",
            "[",
            '{"',
            "type",
            '":',
            ' "',
            "function",
            '",',
            ' "',
            "function",
            '":',
            ' {"',
            "name",
            '":',
            ' "',
            "get",
            "_",
            "weather",
            '",',
            ' "',
            "description",
            '":',
            ' "',
            "G",
            "ets",
            " the",
            " current",
            " weather",
            " in",
            " a",
            " city",
            '.",',
            ' "',
            "parameters",
            '":',
            ' {"',
            "type",
            '":',
            ' "',
            "object",
            '",',
            ' "',
            "properties",
            '":',
            ' {"',
            "city",
            '":',
            ' {"',
            "type",
            '":',
            ' "',
            "string",
            '",',
            ' "',
            "description",
            '":',
            ' "',
            "The",
            " city",
            " name",
            '"',
            "}},",
            ' "',
            "required",
            '":',
            ' ["',
            "city",
            '"]',
            "}}",
            "}]",
            "[/AVAILABLE_TOOLS]",
            "[INST]",
            "Hello",
            " world",
            " !",
            "[/INST]",
            "[TOOL_CALLS]",
            "get",
            "_",
            "weather",
            "[ARGS]",
            '{"',
            "city",
            '":',
            ' "',
            "Paris",
            '"}',
            "</s>",
            "[TOOL_RESULTS]",
            '{"',
            "temperature",
            '":',
            " ",
            "2",
            "0",
            ",",
            ' "',
            "unit",
            '":',
            ' "',
            "c",
            "elsius",
            '"}',
            "[/TOOL_RESULTS]",
        ]
        token_lists = (non_tekken_tokens, tekken_tokens)

        expected_texts = (
            '[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}] I am an AI\n\nHello world ![TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}, "id": "123456789"}] {"content": {"temperature": 20, "unit": "celsius"}, "call_id": "123456789"}',  # noqa: E501
            'I am an AI[{"type": "function", "function": {"name": "get_weather", "description": "Gets the current weather in a city.", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name"}}, "required": ["city"]}}}]Hello world ![TOOL_CALLS]get_weather{"city": "Paris"}{"temperature": 20, "unit": "celsius"}',  # noqa: E501
        )

        variant = mistral_tokenizer.is_tekken
        text = mistral_tokenizer.convert_tokens_to_string(token_lists[variant])
        assert text == expected_texts[variant]

    @pytest.mark.parametrize(
        "skip_special_tokens,tuple_expected_tokens",
        (
            (
                True,
                (
                    [
                        "▁[",
                        '{"',
                        "type",
                        '":',
                        '▁"',
                        "function",
                        '",',
                        '▁"',
                        "function",
                        '":',
                        '▁{"',
                        "name",
                        '":',
                        '▁"',
                        "get",
                        "_",
                        "we",
                        "ather",
                        '",',
                        '▁"',
                        "description",
                        '":',
                        '▁"',
                        "Get",
                        "s",
                        "▁the",
                        "▁current",
                        "▁weather",
                        "▁in",
                        "▁a",
                        "▁city",
                        '.",',
                        '▁"',
                        "parameters",
                        '":',
                        '▁{"',
                        "type",
                        '":',
                        '▁"',
                        "object",
                        '",',
                        '▁"',
                        "properties",
                        '":',
                        '▁{"',
                        "city",
                        '":',
                        '▁{"',
                        "type",
                        '":',
                        '▁"',
                        "string",
                        '",',
                        '▁"',
                        "description",
                        '":',
                        '▁"',
                        "The",
                        "▁city",
                        "▁name",
                        '"',
                        "}},",
                        '▁"',
                        "required",
                        '":',
                        '▁["',
                        "city",
                        '"]',
                        "}}",
                        "}]",
                        "▁I",
                        "▁am",
                        "▁an",
                        "▁AI",
                        "<0x0A>",
                        "<0x0A>",
                        "Hello",
                        "▁world",
                        "▁!",
                        "[TOOL_CALLS]",
                        "▁[",
                        '{"',
                        "name",
                        '":',
                        '▁"',
                        "get",
                        "_",
                        "we",
                        "ather",
                        '",',
                        '▁"',
                        "arguments",
                        '":',
                        '▁{"',
                        "city",
                        '":',
                        '▁"',
                        "Par",
                        "is",
                        '"},',
                        '▁"',
                        "id",
                        '":',
                        '▁"',
                        "1",
                        "2",
                        "3",
                        "4",
                        "5",
                        "6",
                        "7",
                        "8",
                        "9",
                        '"',
                        "}]",
                        '▁{"',
                        "content",
                        '":',
                        '▁{"',
                        "t",
                        "emperature",
                        '":',
                        "▁",
                        "2",
                        "0",
                        ",",
                        '▁"',
                        "unit",
                        '":',
                        '▁"',
                        "c",
                        "els",
                        "ius",
                        '"},',
                        '▁"',
                        "call",
                        "_",
                        "id",
                        '":',
                        '▁"',
                        "1",
                        "2",
                        "3",
                        "4",
                        "5",
                        "6",
                        "7",
                        "8",
                        "9",
                        '"}',
                    ],
                    [
                        "I",
                        " am",
                        " an",
                        " AI",
                        "[",
                        '{"',
                        "type",
                        '":',
                        ' "',
                        "function",
                        '",',
                        ' "',
                        "function",
                        '":',
                        ' {"',
                        "name",
                        '":',
                        ' "',
                        "get",
                        "_",
                        "weather",
                        '",',
                        ' "',
                        "description",
                        '":',
                        ' "',
                        "G",
                        "ets",
                        " the",
                        " current",
                        " weather",
                        " in",
                        " a",
                        " city",
                        '.",',
                        ' "',
                        "parameters",
                        '":',
                        ' {"',
                        "type",
                        '":',
                        ' "',
                        "object",
                        '",',
                        ' "',
                        "properties",
                        '":',
                        ' {"',
                        "city",
                        '":',
                        ' {"',
                        "type",
                        '":',
                        ' "',
                        "string",
                        '",',
                        ' "',
                        "description",
                        '":',
                        ' "',
                        "The",
                        " city",
                        " name",
                        '"',
                        "}},",
                        ' "',
                        "required",
                        '":',
                        ' ["',
                        "city",
                        '"]',
                        "}}",
                        "}]",
                        "Hello",
                        " world",
                        " !",
                        "[TOOL_CALLS]",
                        "get",
                        "_",
                        "weather",
                        '{"',
                        "city",
                        '":',
                        ' "',
                        "Paris",
                        '"}',
                        '{"',
                        "temperature",
                        '":',
                        " ",
                        "2",
                        "0",
                        ",",
                        ' "',
                        "unit",
                        '":',
                        ' "',
                        "c",
                        "elsius",
                        '"}',
                    ],
                ),
            ),
            (
                False,
                (
                    [
                        "<s>",
                        "[AVAILABLE_TOOLS]",
                        "▁[",
                        '{"',
                        "type",
                        '":',
                        '▁"',
                        "function",
                        '",',
                        '▁"',
                        "function",
                        '":',
                        '▁{"',
                        "name",
                        '":',
                        '▁"',
                        "get",
                        "_",
                        "we",
                        "ather",
                        '",',
                        '▁"',
                        "description",
                        '":',
                        '▁"',
                        "Get",
                        "s",
                        "▁the",
                        "▁current",
                        "▁weather",
                        "▁in",
                        "▁a",
                        "▁city",
                        '.",',
                        '▁"',
                        "parameters",
                        '":',
                        '▁{"',
                        "type",
                        '":',
                        '▁"',
                        "object",
                        '",',
                        '▁"',
                        "properties",
                        '":',
                        '▁{"',
                        "city",
                        '":',
                        '▁{"',
                        "type",
                        '":',
                        '▁"',
                        "string",
                        '",',
                        '▁"',
                        "description",
                        '":',
                        '▁"',
                        "The",
                        "▁city",
                        "▁name",
                        '"',
                        "}},",
                        '▁"',
                        "required",
                        '":',
                        '▁["',
                        "city",
                        '"]',
                        "}}",
                        "}]",
                        "[/AVAILABLE_TOOLS]",
                        "[INST]",
                        "▁I",
                        "▁am",
                        "▁an",
                        "▁AI",
                        "<0x0A>",
                        "<0x0A>",
                        "Hello",
                        "▁world",
                        "▁!",
                        "[/INST]",
                        "[TOOL_CALLS]",
                        "▁[",
                        '{"',
                        "name",
                        '":',
                        '▁"',
                        "get",
                        "_",
                        "we",
                        "ather",
                        '",',
                        '▁"',
                        "arguments",
                        '":',
                        '▁{"',
                        "city",
                        '":',
                        '▁"',
                        "Par",
                        "is",
                        '"},',
                        '▁"',
                        "id",
                        '":',
                        '▁"',
                        "1",
                        "2",
                        "3",
                        "4",
                        "5",
                        "6",
                        "7",
                        "8",
                        "9",
                        '"',
                        "}]",
                        "</s>",
                        "[TOOL_RESULTS]",
                        '▁{"',
                        "content",
                        '":',
                        '▁{"',
                        "t",
                        "emperature",
                        '":',
                        "▁",
                        "2",
                        "0",
                        ",",
                        '▁"',
                        "unit",
                        '":',
                        '▁"',
                        "c",
                        "els",
                        "ius",
                        '"},',
                        '▁"',
                        "call",
                        "_",
                        "id",
                        '":',
                        '▁"',
                        "1",
                        "2",
                        "3",
                        "4",
                        "5",
                        "6",
                        "7",
                        "8",
                        "9",
                        '"}',
                        "[/TOOL_RESULTS]",
                    ],
                    [
                        "<s>",
                        "[SYSTEM_PROMPT]",
                        "I",
                        " am",
                        " an",
                        " AI",
                        "[/SYSTEM_PROMPT]",
                        "[AVAILABLE_TOOLS]",
                        "[",
                        '{"',
                        "type",
                        '":',
                        ' "',
                        "function",
                        '",',
                        ' "',
                        "function",
                        '":',
                        ' {"',
                        "name",
                        '":',
                        ' "',
                        "get",
                        "_",
                        "weather",
                        '",',
                        ' "',
                        "description",
                        '":',
                        ' "',
                        "G",
                        "ets",
                        " the",
                        " current",
                        " weather",
                        " in",
                        " a",
                        " city",
                        '.",',
                        ' "',
                        "parameters",
                        '":',
                        ' {"',
                        "type",
                        '":',
                        ' "',
                        "object",
                        '",',
                        ' "',
                        "properties",
                        '":',
                        ' {"',
                        "city",
                        '":',
                        ' {"',
                        "type",
                        '":',
                        ' "',
                        "string",
                        '",',
                        ' "',
                        "description",
                        '":',
                        ' "',
                        "The",
                        " city",
                        " name",
                        '"',
                        "}},",
                        ' "',
                        "required",
                        '":',
                        ' ["',
                        "city",
                        '"]',
                        "}}",
                        "}]",
                        "[/AVAILABLE_TOOLS]",
                        "[INST]",
                        "Hello",
                        " world",
                        " !",
                        "[/INST]",
                        "[TOOL_CALLS]",
                        "get",
                        "_",
                        "weather",
                        "[ARGS]",
                        '{"',
                        "city",
                        '":',
                        ' "',
                        "Paris",
                        '"}',
                        "</s>",
                        "[TOOL_RESULTS]",
                        '{"',
                        "temperature",
                        '":',
                        " ",
                        "2",
                        "0",
                        ",",
                        ' "',
                        "unit",
                        '":',
                        ' "',
                        "c",
                        "elsius",
                        '"}',
                        "[/TOOL_RESULTS]",
                    ],
                ),
            ),
        ),
    )
    def test_convert_ids_to_tokens(
        self,
        mistral_tokenizer: MistralTokenizer,
        skip_special_tokens: bool,
        tuple_expected_tokens: tuple[list[str], list[str]],
    ):
        """Check ``convert_ids_to_tokens`` against fixed id sequences.

        One id sequence is kept per tokenizer variant. The one matching
        ``mistral_tokenizer.is_tekken`` is decoded, and the result is compared
        with the entry at the same position in ``tuple_expected_tokens``.
        Position 0 holds the non-Tekken expectation; position 1 holds the
        Tekken one.
        """
        # Ids for the non-Tekken tokenizer (selected when ``is_tekken`` is
        # false, i.e. index 0).
        non_tekken_ids = [
            1,
            6,
            1501,
            7567,
            1891,
            2032,
            1113,
            3396,
            1316,
            1113,
            3396,
            2032,
            10598,
            1629,
            2032,
            1113,
            1295,
            29498,
            1537,
            1991,
            1316,
            1113,
            7286,
            2032,
            1113,
            2226,
            29481,
            1040,
            2636,
            8854,
            1065,
            1032,
            3758,
            9959,
            1113,
            12206,
            2032,
            10598,
            1891,
            2032,
            1113,
            3582,
            1316,
            1113,
            11491,
            2032,
            10598,
            19141,
            2032,
            10598,
            1891,
            2032,
            1113,
            2195,
            1316,
            1113,
            7286,
            2032,
            1113,
            1782,
            3758,
            1909,
            29507,
            11549,
            1113,
            11661,
            2032,
            8135,
            19141,
            3010,
            1743,
            10925,
            7,
            3,
            1083,
            1605,
            1164,
            16875,
            781,
            781,
            16998,
            2294,
            1686,
            4,
            5,
            1501,
            7567,
            1629,
            2032,
            1113,
            1295,
            29498,
            1537,
            1991,
            1316,
            1113,
            17452,
            2032,
            10598,
            19141,
            2032,
            1113,
            4684,
            1046,
            8474,
            1113,
            1081,
            2032,
            1113,
            29508,
            29518,
            29538,
            29549,
            29550,
            29552,
            29555,
            29551,
            29542,
            29507,
            10925,
            2,
            8,
            10598,
            4557,
            2032,
            10598,
            29475,
            17329,
            2032,
            29473,
            29518,
            29502,
            29493,
            1113,
            6074,
            2032,
            1113,
            29485,
            1958,
            3938,
            8474,
            1113,
            3613,
            29498,
            1081,
            2032,
            1113,
            29508,
            29518,
            29538,
            29549,
            29550,
            29552,
            29555,
            29551,
            29542,
            18163,
            9,
        ]
        # Ids for the Tekken tokenizer (selected when ``is_tekken`` is true,
        # i.e. index 1).
        tekken_ids = [
            1,
            17,
            1073,
            1855,
            1420,
            26554,
            18,
            5,
            1091,
            19227,
            4994,
            2811,
            1429,
            5165,
            1897,
            1429,
            5165,
            2811,
            16753,
            2391,
            2811,
            1429,
            1689,
            1095,
            45629,
            1897,
            1429,
            14653,
            2811,
            1429,
            1071,
            3083,
            1278,
            3519,
            17253,
            1294,
            1261,
            5970,
            39249,
            1429,
            26204,
            2811,
            16753,
            4994,
            2811,
            1429,
            6371,
            1897,
            1429,
            48649,
            2811,
            16753,
            29363,
            2811,
            16753,
            4994,
            2811,
            1429,
            3607,
            1897,
            1429,
            14653,
            2811,
            1429,
            1784,
            5970,
            2564,
            1034,
            47579,
            1429,
            15760,
            2811,
            12161,
            29363,
            4964,
            2821,
            27028,
            6,
            3,
            22177,
            4304,
            2662,
            4,
            9,
            1689,
            1095,
            45629,
            32,
            19227,
            29363,
            2811,
            1429,
            42572,
            46005,
            2,
            7,
            19227,
            113824,
            2811,
            1032,
            1050,
            1048,
            1044,
            1429,
            8979,
            2811,
            1429,
            1099,
            79092,
            46005,
            8,
        ]

        # ``is_tekken`` serves as a 0/1 index into both pairs, so the ids and
        # the expectation always come from the same tokenizer variant.
        variant = mistral_tokenizer.is_tekken
        ids_for_variant = (non_tekken_ids, tekken_ids)[variant]
        wanted = tuple_expected_tokens[variant]

        decoded = mistral_tokenizer.convert_ids_to_tokens(
            ids_for_variant, skip_special_tokens=skip_special_tokens
        )
        assert decoded == wanted
